# Purpose: Produce some graphs that show how integration
#  measures change over time for users that arrived in 2015-2016.

source("constants.R") # a set of constants

# Display labels of the four outcome measures. The order given here is used
# as the factor level order of `measure` further down in this script.
outcome_order <- c(
  "Avg. Native Friends",
  "Avg. % of Friends Natives",
  "Avg. % Produced DE",
  "Avg. % Consumed DE"
)

# The two pairs of outcomes shown together in the final summary figures.
outcomes_SPLIT_A <- c(
  "Avg. Native Friends",
  "Avg. % Produced DE"
)
outcomes_SPLIT_B <- c(
  "Avg. % of Friends Natives",
  "Avg. % Consumed DE"
)


############################
### 1. Overall Measures  ###
############################

# Input for the overall measures.
dat <- read_csv("input/pipeline/integration_over_time_input.csv")

# Give the four outcome columns their display labels, then stack every column
# except quarters_since_DE into measure/value pairs (long format).
dat_to_write <- rename(
  dat,
  `Avg. Native Friends` = avg_friendships_to_curr_country_natives,
  `Avg. % of Friends Natives` = avg_share_frnds_natives,
  `Avg. % Produced DE` = avg_share_produced_german,
  `Avg. % Consumed DE` = avg_share_consumed_german
)
dat_to_write <- gather(dat_to_write, measure, value, -quarters_since_DE)

# Order the measures as in outcome_order and keep quarters 0 through 12.
dat_to_write <- mutate(
  dat_to_write,
  measure = factor(measure, levels = outcome_order)
)
dat_to_write <- filter(
  dat_to_write,
  quarters_since_DE >= 0,
  quarters_since_DE <= 12
)

# One facet row per measure, each with its own y-axis; the legend is hidden
# because the facet strips already name the measures.
# NOTE(review): the filter keeps quarter 0 but the axis breaks start at 1 --
# confirm that is intended.
ggplot(
  dat_to_write,
  aes(x = quarters_since_DE, y = value, colour = measure)
) +
  geom_point(size = 3) +
  facet_grid(rows = vars(measure), scales = "free") +
  scale_x_continuous(breaks = 1:12) +
  scale_colour_brewer(palette = "Set2") +
  labs(x = "Quarters in Germany", y = "Value") +
  theme_bw() +
  theme(legend.position = "none")

# Persist both the plotted data and the figure.
write_csv(dat_to_write, "output/integration_2015_2016_cohort.csv")
ggsave(
  filename = "output/integration_2015_2016_cohort.png",
  plot = last_plot(),
  width = 6.5,
  height = 6.5
)



############################
### 2. Measures by Age  ###
############################

# Input for the measures split by age at arrival.
dat_age_at_arrival <- read_csv("input/pipeline/integration_over_time_age_at_arrival.csv")

# Label the outcome columns for display and reshape to long format, keeping
# quarters_since_DE and age_when_arrived as identifier columns.
dat_to_write.age <- rename(
  dat_age_at_arrival,
  `Avg. Native Friends` = avg_friendships_to_curr_country_natives,
  `Avg. % of Friends Natives` = avg_share_frnds_natives,
  `Avg. % Produced DE` = avg_share_produced_german,
  `Avg. % Consumed DE` = avg_share_consumed_german
)
dat_to_write.age <- gather(
  dat_to_write.age,
  measure, value,
  -quarters_since_DE, -age_when_arrived
)

# Order the measures as in outcome_order and keep quarters 0 through 12.
dat_to_write.age <- mutate(
  dat_to_write.age,
  measure = factor(measure, levels = outcome_order)
)
dat_to_write.age <- filter(
  dat_to_write.age,
  quarters_since_DE >= 0,
  quarters_since_DE <= 12
)

# One facet row per measure (free y-axes), coloured by age at arrival.
# NOTE(review): the filter keeps quarter 0 but the axis breaks start at 1 --
# confirm that is intended.
ggplot(
  dat_to_write.age,
  aes(x = quarters_since_DE, y = value, colour = age_when_arrived)
) +
  geom_point() +
  facet_grid(rows = vars(measure), scales = "free") +
  scale_x_continuous(breaks = 1:12) +
  scale_colour_brewer(palette = "Set2") +
  labs(x = "Quarters in Germany", y = "Value", colour = "Age When Arrived") +
  theme_bw()

# NOTE(review): "integreation" in the two output names below is misspelled
# (other outputs of this script use "integration"). The names are left as they
# are because other code may read these files -- fix together with consumers.
write_csv(dat_to_write.age, "output/integreation_2015_2016_cohort_by_age.csv")
ggsave(
  filename = "output/integreation_2015_2016_cohort_by_age.png",
  plot = last_plot(),
  width = 6.5,
  height = 6.5
)



##############################
### 3. Measures by Gender  ###
##############################

# Input for the measures split by gender.
dat_gender <- read_csv("input/pipeline/integration_over_time_gender.csv")

# Label the outcome columns for display and reshape to long format, keeping
# quarters_since_DE and gender as identifier columns.
dat_to_write.gender <- rename(
  dat_gender,
  `Avg. Native Friends` = avg_friendships_to_curr_country_natives,
  `Avg. % of Friends Natives` = avg_share_frnds_natives,
  `Avg. % Produced DE` = avg_share_produced_german,
  `Avg. % Consumed DE` = avg_share_consumed_german
)
dat_to_write.gender <- gather(
  dat_to_write.gender,
  measure, value,
  -quarters_since_DE, -gender
)

# Order the measures as in outcome_order and keep quarters 0 through 12.
dat_to_write.gender <- mutate(
  dat_to_write.gender,
  measure = factor(measure, levels = outcome_order)
)
dat_to_write.gender <- filter(
  dat_to_write.gender,
  quarters_since_DE >= 0,
  quarters_since_DE <= 12
)

# One facet row per measure (free y-axes), coloured by gender.
# NOTE(review): the filter keeps quarter 0 but the axis breaks start at 1 --
# confirm that is intended.
ggplot(
  dat_to_write.gender,
  aes(x = quarters_since_DE, y = value, colour = gender)
) +
  geom_point() +
  facet_grid(rows = vars(measure), scales = "free") +
  scale_x_continuous(breaks = 1:12) +
  scale_colour_brewer(palette = "Set2") +
  labs(x = "Quarters in Germany", y = "Value", colour = "Gender") +
  theme_bw()

# NOTE(review): "integreation" in the two output names below is misspelled
# (other outputs of this script use "integration"). The names are left as they
# are because other code may read these files -- fix together with consumers.
write_csv(dat_to_write.gender, "output/integreation_2015_2016_cohort_by_gender.csv")
ggsave(
  filename = "output/integreation_2015_2016_cohort_by_gender.png",
  plot = last_plot(),
  width = 6.5,
  height = 6.5
)



##################################
### 4. Measures by Gender/Age  ###
##################################

# In this step, we add series to the plots one-by-one
dat_age_gender <- read_csv("input/pipeline/integration_over_time_gender_age.csv")

# Label the outcome columns for display, reshape to long format (keeping
# quarters_since_DE and age_gender_when_arrived as identifiers), order the
# measures as in outcome_order and keep quarters 0 through 12.
dat_to_write.age_gender <- dat_age_gender %>%
  rename(
    `Avg. Native Friends` = avg_friendships_to_curr_country_natives,
    `Avg. % of Friends Natives` = avg_share_frnds_natives,
    `Avg. % Produced DE` = avg_share_produced_german,
    `Avg. % Consumed DE` = avg_share_consumed_german) %>%
  gather(measure, value, -quarters_since_DE, -age_gender_when_arrived) %>%
  mutate(measure = factor(measure, levels = outcome_order)) %>%
  filter(quarters_since_DE >= 0 & quarters_since_DE <= 12)

# Build one "Avg. Native Friends" plot restricted to some age groups.
#
# The four step plots below differ only in which age groups they show and in
# the colours/shapes used, so they share this single implementation.
#
# dat:         long-format data with columns quarters_since_DE, measure, value
#              and age_gender_when_arrived. The part of
#              age_gender_when_arrived after the first ", " is compared with
#              `age_groups`.
# age_groups:  character vector of age-group labels to keep (e.g. "13-18").
# palette_idx: indices into the 8-colour "Set2" palette, one per plotted group.
# shapes:      point shapes, one per plotted group.
#
# The colour and shape values are unnamed, so they are matched to the groups
# by position in the scale.
# NOTE(review): the index patterns used below (e.g. c(1, 3, 2, 4) with shapes
# c(19, 19, 17, 17)) presumably assume that all groups of one gender sort
# before all groups of the other -- verify against the labels in the data.
#
# Returns the ggplot object.
plot_native_friends_step <- function(dat, age_groups, palette_idx, shapes) {
  dat %>%
    filter(
      measure == "Avg. Native Friends",
      str_split(age_gender_when_arrived, ", ", simplify = TRUE)[, 2] %in% age_groups
    ) %>%
    ggplot(aes(
      x = quarters_since_DE, y = value,
      colour = age_gender_when_arrived, shape = age_gender_when_arrived
    )) +
    geom_point(size = 3) +
    theme_bw() +
    labs(x = "Quarters in Germany", y = "Avg. Native Friends") +
    scale_colour_manual(
      name = "Gender, Age When Arrived",
      values = brewer.pal(n = 8, name = "Set2")[palette_idx]
    ) +
    scale_shape_manual(
      name = "Gender, Age When Arrived",
      values = shapes
    ) +
    scale_x_continuous(breaks = 1:12)
}

# Each step is a top-level call, so the plot is displayed exactly as a
# top-level ggplot expression would be, and last_plot() refers to it.

## Step 1
plot_native_friends_step(
  dat_to_write.age_gender,
  age_groups = "13-18",
  palette_idx = c(1, 2),
  shapes = c(19, 17)
)
ggsave("output/integration_2015_2016_cohort_by_age_gender_STEP1.png", last_plot(), height = 4, width = 6.5)

## Step 2
plot_native_friends_step(
  dat_to_write.age_gender,
  age_groups = c("13-18", "19-30"),
  palette_idx = c(1, 3, 2, 4),
  shapes = c(19, 19, 17, 17)
)
ggsave("output/integration_2015_2016_cohort_by_age_gender_STEP2.png", last_plot(), height = 4, width = 6.5)

## Step 3
plot_native_friends_step(
  dat_to_write.age_gender,
  age_groups = c("13-18", "19-30", "31-45"),
  palette_idx = c(1, 3, 5, 2, 4, 6),
  shapes = c(19, 19, 19, 17, 17, 17)
)
ggsave("output/integration_2015_2016_cohort_by_age_gender_STEP3.png", last_plot(), height = 4, width = 6.5)

## Step 4
plot_native_friends_step(
  dat_to_write.age_gender,
  age_groups = c("13-18", "19-30", "31-45", "46+"),
  palette_idx = c(1, 3, 5, 7, 2, 4, 6, 8),
  shapes = c(19, 19, 19, 19, 17, 17, 17, 17)
)
ggsave("output/integration_2015_2016_cohort_by_age_gender_STEP4.png", last_plot(), height = 4, width = 6.5)

## Overall
# All measures and all groups: one facet row per measure with free y-axes.
ggplot(
  dat_to_write.age_gender,
  aes(
    x = quarters_since_DE, y = value,
    colour = age_gender_when_arrived, shape = age_gender_when_arrived
  )
) +
  geom_point() +
  facet_grid(rows = vars(measure), scales = "free") +
  theme_bw() +
  labs(x = "Quarters in Germany", y = "Value") +
  scale_colour_manual(
    name = "Gender, Age When Arrived",
    values = brewer.pal(n = 8, name = "Set2")[c(1, 3, 5, 7, 2, 4, 6, 8)]
  ) +
  scale_shape_manual(
    name = "Gender, Age When Arrived",
    values = c(19, 19, 19, 19, 17, 17, 17, 17)
  ) +
  scale_x_continuous(breaks = 1:12)

ggsave("output/integration_2015_2016_cohort_by_age_gender.png", last_plot(), height = 6.5, width = 6.5)


##############################
### 5. Final Summary Plots ###
##############################

# Function to extract a legend
# https://github.com/hadley/ggplot2/wiki/Share-a-legend-between-two-ggplot2-graphs
#
# a.gplot: a ggplot object.
#
# Builds the plot's gtable and returns the first grob named "guide-box".
# Stops with an explicit error when the plot has no such grob (for example
# when its legend is switched off) instead of failing on an empty subscript.
g_legend <- function(a.gplot) {
  tmp <- ggplot_gtable(ggplot_build(a.gplot))

  # vapply() keeps this a character vector even if a grob has no name.
  grob_names <- vapply(
    tmp$grobs,
    function(x) {
      if (is.null(x$name)) NA_character_ else as.character(x$name)[1]
    },
    character(1)
  )

  # which() ignores the NA entries produced for unnamed grobs.
  leg <- which(grob_names == "guide-box")
  if (length(leg) == 0) {
    stop(
      "g_legend(): the plot has no legend ('guide-box') grob to extract",
      call. = FALSE
    )
  }

  # If several grobs match, use the first rather than indexing recursively.
  tmp$grobs[[leg[1]]]
}

# Function for the baseline plot
#
# dat: long-format data with columns quarters_since_DE, value and measure.
#
# Draws value against quarters_since_DE as black squares, without a legend.
# Colour is mapped to `measure` but only one manual colour is supplied; the
# callers in this file pass the rows of a single measure.
#
# Returns the ggplot object.
make_baseline_plot <- function(dat) {
  ggplot(dat, aes(x = quarters_since_DE, y = value, colour = measure)) +
    geom_point(shape = 15, size = 2.5) +
    scale_x_continuous(breaks = 1:12) +
    scale_colour_manual(values = "black") +
    labs(x = "Quarters in Germany", y = "Value") +
    theme_bw() +
    # Tweaks must come after theme_bw(), which sets a complete theme.
    theme(
      legend.position = "none",
      axis.title = element_text(size = 10),
      plot.title = element_text(size = 12)
    )
}

# Function for the heterogeneity by age/gender plot
#
# dat: long-format data with columns quarters_since_DE, value and
#      age_gender_when_arrived.
#
# Draws value against quarters_since_DE with one colour/shape combination per
# level of age_gender_when_arrived. Eight colours and eight shapes are
# supplied; the values are unnamed, so they are matched to the groups by
# position in the scale.
# NOTE(review): the colour order c(1,3,5,7,2,4,6,8) and the 4 + 4 shape split
# presumably assume that the four groups of one gender sort before the four
# groups of the other -- verify against the labels in the data.
#
# Returns the ggplot object.
make_hetero_plot <- function(dat) {
  # Only age_gender_when_arrived is mapped, so no separate gender column is
  # derived here.
  dat %>%
    ggplot(aes(
      x = quarters_since_DE, y = value,
      colour = age_gender_when_arrived, shape = age_gender_when_arrived
    )) +
    geom_point() +
    theme_bw() +
    labs(x = "Quarters in Germany", y = "Value") +
    scale_colour_manual(
      name = "Gender, Age When Arrived",
      values = brewer.pal(n = 8, name = "Set2")[c(1, 3, 5, 7, 2, 4, 6, 8)]
    ) +
    scale_shape_manual(
      name = "Gender, Age When Arrived",
      values = c(19, 19, 19, 19, 17, 17, 17, 17)
    ) +
    scale_x_continuous(breaks = 1:12) +
    theme(
      axis.title = element_text(size = 10),
      plot.title = element_text(size = 12)
    )
}


# Final function that combines the other helpers to produce the actual figure.
#
# Layout: the top row holds the baseline plots of outcomes[1] (left) and
# outcomes[2] (right), the middle row holds the corresponding age/gender
# plots, and a shared legend spans the bottom. The two panels of a column use
# the same y-axis limits.
#
# dat:        long-format overall data (quarters_since_DE, measure, value).
# dat_hetero: long-format data by group (additionally
#             age_gender_when_arrived).
# outcomes:   character vector; elements 1 and 2 name the measures shown in
#             the left and right column.
#
# Returns, invisibly, the result of grid.arrange(), which also draws the
# figure on the current graphics device.
make_final_figures <- function(dat, dat_hetero, outcomes) {

  # Rows of `x` that belong to a single measure.
  rows_for <- function(x, outcome) {
    filter(x, measure == outcome)
  }

  # y-limits shared by both panels of one column. They span the baseline AND
  # the by-group values: lims() removes observations outside the limits, so
  # limits taken from the by-group data alone would silently drop any baseline
  # point lying outside that range. Missing values are ignored.
  shared_limits <- function(outcome) {
    range(
      rows_for(dat, outcome)$value,
      rows_for(dat_hetero, outcome)$value,
      na.rm = TRUE
    )
  }

  lims1 <- shared_limits(outcomes[1])
  lims2 <- shared_limits(outcomes[2])

  # Top row: baseline panels, titled with the outcome name.
  p1 <- make_baseline_plot(rows_for(dat, outcomes[1])) +
    theme(legend.position = "none") +
    labs(title = outcomes[1]) + lims(y = lims1)

  p2 <- make_baseline_plot(rows_for(dat, outcomes[2])) +
    theme(legend.position = "none") +
    labs(title = outcomes[2]) + lims(y = lims2)

  # Middle row: by-group panels; their legends are hidden because one shared
  # legend is placed underneath.
  p3 <- make_hetero_plot(rows_for(dat_hetero, outcomes[1])) +
    theme(legend.position = "none") + lims(y = lims1)

  p4 <- make_hetero_plot(rows_for(dat_hetero, outcomes[2])) +
    theme(legend.position = "none") + lims(y = lims2)

  # Shared legend, extracted from a bottom-legend version of the first
  # by-group plot.
  mylegend <- g_legend(
    make_hetero_plot(rows_for(dat_hetero, outcomes[1])) +
      theme(legend.position = "bottom") +
      guides(
        colour = guide_legend(title.position = "top", title.hjust = 0.5),
        shape = guide_legend(title.position = "top", title.hjust = 0.5)
      )
  )

  # Seven layout rows: three for each plot row and one for the legend.
  p <- grid.arrange(
    grobs = list(p1, p2, p3, p4, mylegend),
    widths = c(1, 1),
    layout_matrix = rbind(c(1, 2),
                          c(1, 2),
                          c(1, 2),
                          c(3, 4),
                          c(3, 4),
                          c(3, 4),
                          c(5, 5))
  )

  invisible(p)
}

# This produces Figure 1

# Split A: the outcomes listed in outcomes_SPLIT_A.
plot_A <- make_final_figures(
  dat = dat_to_write,
  dat_hetero = dat_to_write.age_gender,
  outcomes = outcomes_SPLIT_A
)
ggsave(
  filename = "output/integration_2015_2016_cohort_SPLIT_A.png",
  plot = plot_A,
  width = 6.5,
  height = 6.5
)

# Split B: the outcomes listed in outcomes_SPLIT_B.
plot_B <- make_final_figures(
  dat = dat_to_write,
  dat_hetero = dat_to_write.age_gender,
  outcomes = outcomes_SPLIT_B
)
ggsave(
  filename = "output/integration_2015_2016_cohort_SPLIT_B.png",
  plot = plot_B,
  width = 6.5,
  height = 6.5
)
